from tensorflow import keras
from keras import backend as K
from keras.engine.topology import Layer

class DssmBase():
    """Two-tower DSSM-style similarity model.

    Both input sequences share one trainable Embedding layer; each is
    mean-pooled over the time axis into a fixed-size vector, and the two
    vectors' dot product is squashed through a sigmoid Dense unit to
    predict a match probability in [0, 1].

    Expected ``model_conf`` keys (as read by ``__build_structure__``):
        MAX_LEN   -- padded sequence length of each input.
        w2c_len   -- embedding dimensionality.
        emb_model -- object exposing ``embedding_weights`` (vocab-sized
                     container) and ``get_np_weights()`` (initial matrix).

    ``train_conf`` optionally overrides ``batch_size`` / ``epochs`` /
    ``verbose`` for ``fit``.
    """

    def __init__(self, model_conf, train_conf=None):
        # train_conf defaults via a None sentinel rather than a mutable
        # dict default, which would be shared across all instances.
        self.model_conf = model_conf
        self.train_conf = train_conf if train_conf is not None else {
            "batch_size": 64, "epochs": 3, "verbose": 1}
        self.__build_structure__()

    def __build_structure__(self):
        """Build and compile the twin-input matching network."""
        inputs_1 = keras.layers.Input(shape=(self.model_conf["MAX_LEN"],), name="inputs_1")
        inputs_2 = keras.layers.Input(shape=(self.model_conf["MAX_LEN"],), name="inputs_2")

        # One embedding layer shared by both towers, seeded from the
        # pretrained matrix but left trainable so it can be fine-tuned.
        embedding_layer = keras.layers.Embedding(
            output_dim=self.model_conf["w2c_len"],
            input_dim=len(self.model_conf["emb_model"].embedding_weights),
            weights=[self.model_conf["emb_model"].get_np_weights()],
            input_length=self.model_conf["MAX_LEN"],
            trainable=True,
            mask_zero=True,
        )

        embedding_seq_1 = embedding_layer(inputs_1)
        embedding_seq_2 = embedding_layer(inputs_2)

        # NOTE(review): mask_zero=True emits a padding mask, but a Lambda
        # wrapping K.mean does not consume it — padding timesteps are
        # included in the average. Confirm whether a mask-aware pooling
        # (e.g. GlobalAveragePooling1D with mask support) was intended.
        flatten_1 = keras.layers.Lambda(lambda x: keras.backend.mean(x, axis=1), name="flatten_1")(embedding_seq_1)
        flatten_2 = keras.layers.Lambda(lambda x: keras.backend.mean(x, axis=1), name="flatten_2")(embedding_seq_2)

        # Dot product of the two pooled vectors -> scalar similarity.
        output = keras.layers.dot([flatten_1, flatten_2], axes=1)
        pred = keras.layers.Dense(units=1, activation='sigmoid')(output)

        self.model = keras.models.Model(inputs=[inputs_1, inputs_2], outputs=pred)
        self.model.summary()
        self.model.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])

    def fit(self, x_train, y_train, x_test, y_test):
        """Train the model; also expose the per-tower encoder sub-models.

        Returns the Keras ``History`` object. After fitting,
        ``self.model_vec_1`` / ``self.model_vec_2`` map each input to its
        pooled embedding vector (the ``flatten_*`` layer outputs).
        """
        history = self.model.fit(x_train, y_train,
                                 batch_size=self.train_conf.get("batch_size", 64),
                                 epochs=self.train_conf.get("epochs", 3),
                                 validation_data=(x_test, y_test),
                                 verbose=self.train_conf.get("verbose", 1))
        self.model_vec_1 = keras.models.Model(inputs=self.model.get_layer('inputs_1').input, outputs=self.model.get_layer('flatten_1').output)
        self.model_vec_2 = keras.models.Model(inputs=self.model.get_layer('inputs_2').input, outputs=self.model.get_layer('flatten_2').output)
        return history

    def evaluate(self, x_test, y_test):
        """Return loss/metric values on held-out data."""
        return self.model.evaluate(x_test, y_test)

    def predict(self, sentences):
        """Return match probabilities for paired input sequences."""
        return self.model.predict(sentences)

    def predict_vec(self, sentences):
        """Return (tower-1 vectors, tower-2 vectors); requires fit() first."""
        return self.model_vec_1.predict(sentences), self.model_vec_2.predict(sentences)

    def save(self, path):
        """Persist the full model to ``path`` if one has been built."""
        if self.model:
            self.model.save(path)

    def load(self, path):
        """Load a previously saved model from ``path``."""
        # Bug fix: keras has no top-level ``load_model``; it lives in
        # ``keras.models`` — the original raised AttributeError.
        self.model = keras.models.load_model(path)